from unittest.mock import MagicMock, Mock

import pytest
import torch
from pytest_mock import MockerFixture
from vllm.config import CacheConfig, VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer


class TestTorchairMtpProposer(PytestBase):
    """Unit tests for ``TorchairMtpProposer`` built against mocked vLLM state."""

    @staticmethod
    def _build_vllm_config():
        """Return a ``VllmConfig`` mock set up for MTP speculation on CPU.

        The speculative config uses method ``"mtp"`` with 5 speculative
        tokens and a float16 draft model; the cache config is a real
        ``CacheConfig`` (not a mock) with ``block_size=16``.
        """
        draft_model_config = MagicMock(dtype=torch.float16)
        speculative_config = MagicMock(draft_model_config=draft_model_config,
                                       method="mtp",
                                       num_speculative_tokens=5)

        config = MagicMock(spec=VllmConfig)
        config.device_config = MagicMock(device=torch.device("cpu"))
        config.speculative_config = speculative_config
        config.load_config = MagicMock()
        config.cache_config = CacheConfig(block_size=16)
        config.scheduler_config = MagicMock(max_num_batched_tokens=1024,
                                            max_num_seqs=64)
        return config

    @staticmethod
    def _build_runner():
        """Return a model-runner mock with pcp/dcp size 1 and pcp rank 0."""
        runner = MagicMock(pcp_size=1,
                           dcp_size=1,
                           pcp_rank=0,
                           max_num_tokens=1024,
                           max_num_reqs=10)
        runner._use_aclgraph.return_value = True
        return runner

    @staticmethod
    def _patch_dependencies(mocker):
        """Patch the vLLM / vllm_ascend hooks used while building the proposer.

        Returns ``(model_loader_mock, deepseek_mtp_model_mock)`` so tests can
        inspect how the proposer interacted with them.

        NOTE(review): several targets point at the defining module (e.g.
        ``vllm.model_executor.model_loader.get_model_loader``); these only
        take effect if the code under test resolves the name through that
        module at call time rather than via a ``from ... import`` binding —
        confirm against ``torchair_mtp_proposer``.
        """
        # The base-class initializer is stubbed out, so the fixture assigns
        # the proposer's attributes by hand after construction.
        mocker.patch(
            "vllm_ascend.torchair.torchair_mtp_proposer.MtpProposer.__init__",
            return_value=None)

        dtype_ctx = mocker.patch("vllm.utils.torch_utils.set_default_torch_dtype")
        dtype_ctx.return_value.__enter__.return_value = None

        loader = MagicMock()
        mocker.patch("vllm.model_executor.model_loader.get_model_loader",
                     return_value=loader)
        mocker.patch("vllm.config.get_layers_from_vllm_config",
                     return_value={
                         "target_attn_layer_1": Mock(),
                         "draft_attn_layer_2": Mock(),
                     })

        vllm_ctx = mocker.patch("vllm.config.set_current_vllm_config")
        vllm_ctx.return_value.__enter__.return_value = None

        mtp_model = MagicMock()
        # `.to(...)` returns the same mock so chained device moves stay on it.
        mtp_model.to.return_value = mtp_model
        mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
            return_value=mtp_model)
        mocker.patch(
            "vllm.model_executor.model_loader.utils.process_weights_after_loading"
        )
        return loader, mtp_model

    @pytest.fixture
    def setup_torchair_mtp_proposer(self, mocker: MockerFixture):
        """Construct a ``TorchairMtpProposer`` on CPU with dependencies mocked.

        Returns:
            A 3-tuple ``(proposer, mock_model_loader,
            mock_torchair_deepseek_mtp)``.
        """
        vllm_config = self._build_vllm_config()
        cpu = torch.device("cpu")
        runner = self._build_runner()
        loader, mtp_model = self._patch_dependencies(mocker)

        proposer = TorchairMtpProposer(vllm_config, cpu, runner)

        # Fill in what the stubbed MtpProposer.__init__ would otherwise set.
        spec_config = vllm_config.speculative_config
        proposer.vllm_config = vllm_config
        proposer.device = cpu
        proposer.runner = runner
        proposer.speculative_config = spec_config
        proposer.draft_model_config = spec_config.draft_model_config
        proposer.method = spec_config.method

        return proposer, loader, mtp_model

    def test_init(self, setup_torchair_mtp_proposer):
        """The fixture yields a genuine ``TorchairMtpProposer`` instance."""
        proposer, _loader, _mtp_model = setup_torchair_mtp_proposer
        assert isinstance(proposer, TorchairMtpProposer)
